import numpy as np
from scipy.optimize import minimize
import pandas as pd
from IPython.display import clear_output
from scipy.optimize import fsolve
import time
import matplotlib.pyplot as plt
import networkx as nx
import copy



def Obj_Functional_Iterations(Functional_estimations, FS, LW, MS, MEW, LS):

    """
    Draw the history of the objective functional over the FBS method iterations
    (equivalently, over successive estimations of the control variable).
    Two separate figures are shown: the first for the integral term, the second
    for the terminal term. When fewer than 5 estimations are supplied, the
    values (rounded to 3 decimals) are also printed.

    Input:
    Functional_estimations - sequence of estimations; element [0] of each entry is
                             the integral term and element [1] is the terminal term
    FS                     - font size of the axis labels
    LW                     - line width
    MS                     - marker size
    MEW                    - marker edge width
    LS                     - label size of the tick params

    Output:
    0 - always
    """

    N_Iter = len(Functional_estimations)

    #---------------------------------------------------------------------------

    # Short histories are additionally reported as plain numbers
    if N_Iter < 5:

        print('')
        print('----------------------------------------------------')
        print('')

        for Estimation in Functional_estimations:

            print('Integral:', round(Estimation[0], 3), 'terminal:', round(Estimation[1], 3))
            print('')

    #---------------------------------------------------------------------------

    # Tick positions are zero-based; the labels shown are one-based iteration numbers.
    # Up to 5 iterations every tick is drawn, otherwise only first / middle / last.
    if N_Iter <= 5:
        Ticks = np.arange(N_Iter)
    else:
        Ticks = np.array([0, int(N_Iter/2), N_Iter-1])

    #---------------------------------------------------------------------------

    # (index of the term inside an estimation, marker, y-axis label)
    Panels = [(0, 'd', 'integral term'),
              (1, 'o', 'terminal term')]

    for Term, Marker, Label in Panels:

        plt.figure(figsize=(5, 5))
        plt.plot([Estimation[Term] for Estimation in Functional_estimations],
                 color='k', lw=LW,
                 marker=Marker, ms=MS, mec='darkgray', mew=MEW,
                )
        plt.tick_params(axis='both', labelsize=LS)
        plt.xticks(Ticks, Ticks+1)
        plt.xlabel('iteration', fontsize=FS)
        plt.ylabel(Label, fontsize=FS)
        plt.show()

    return 0



def Plot_Dynamics(y_t, LW, FS, Arguments):

    r"""
    This function plots dynamics of state variables

    Every state matrix in y_t is summed along axis 1, and each of the first m
    resulting components is drawn as one curve over the index of y_t.
    The input y_t is not modified.

    Input:
    y_t        - the array of state matrices (y (\tau) in the main manuscript)
    LW         - line width
    FS         - font size (legend and x-axis label)
    Arguments  - hyperparameters; only Arguments['m'] (number of curves) is read

    Output:
    0 - always
    """

    m = Arguments['m']

    # Row sums are collected into a NEW array: writing them back into y_t would
    # overwrite the caller's state matrices and break any later use of them
    # (including a second call of this function).
    Row_Sums = np.array([np.asarray(State).sum(axis=1) for State in y_t]).T

    #---------------------------------------------------------------------------

    plt.figure(figsize=(15, 5))

    #---------------------------------------------------------------------------

    # Hand-picked palettes exist only for m = 3 and m = 5; for any other m the
    # colour is left unset so that matplotlib's default colour cycle is used.
    Palettes = {
        3: ['darkorange', 'darkgray', 'darkviolet'],
        5: ['red', 'darkorange', 'darkgray', 'royalblue', 'darkviolet'],
    }
    Colors = Palettes.get(m)

    #---------------------------------------------------------------------------

    for i in range(m):

        # The subscript is wrapped in braces so that two-digit indices are
        # rendered entirely as a subscript.
        Style = {'lw': LW, 'label': f'$y_{{{i + 1}}}$'}

        if Colors is not None:
            Style['color'] = Colors[i]

        plt.plot(Row_Sums[i], **Style)

    # One legend after all curves are drawn (nothing to show when m == 0)
    if m > 0:
        plt.legend(fontsize=FS)

    plt.xlim([0, Row_Sums.shape[1]-1])

    plt.xlabel(r'$\tau$', fontsize=FS)

    plt.show()

    return 0


    
def Show_Control(u_t):

    """
    Print every distinct control matrix together with the list of indices of u_t
    at which that control occurs. Controls are compared after rounding to
    3 decimals and are reported in the order of their first appearance.

    Input:
    u_t - the array of control matrices (u(tau) from the main manuscript)

    Output:
    0 - always
    """

    Rounded = np.round(u_t, 3)  # rounding to 3 decimals discards numerical noise

    # First pass: keep a control only if it differs (in at least one entry)
    # from every control kept so far
    Distinct = []

    for Candidate in Rounded:

        Is_New = all((Candidate != Known).sum() > 0 for Known in Distinct)

        if Is_New:
            Distinct.append(Candidate)

    # Second pass: for each distinct control, list the positions where it is used
    for Pattern in Distinct:

        Matches = [k for k, Current in enumerate(Rounded)
                   if (Pattern != Current).sum() == 0]

        print(Pattern)
        print('')
        print(Matches)
        print('')
        print('------------------------------------------')
        print('')

    return 0



def Define_Arguments(M, m, n, Transition_Matrices, Step, Grid, Tau_0, Tau_1, w, v, n_types):

    """
    Pack all the hyperparameters into a single dictionary.
    Every parameter is stored unchanged under a key equal to its own name.

    Input:
    M, m, n, Transition_Matrices, Step, Grid, Tau_0, Tau_1, w, v, n_types - values to be stored

    Output:
    Arguments - dictionary that stores all the arguments
    """

    return {
        'M': M,
        'm': m,
        'n': n,
        'Transition_Matrices': Transition_Matrices,
        'Step': Step,
        'Grid': Grid,
        'Tau_0': Tau_0,
        'Tau_1': Tau_1,
        'w': w,
        'v': v,
        'n_types': n_types,
    }



